{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Z481aBjzDq3"
      },
      "source": [
        "Copyright 2021 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "executionInfo": {
          "elapsed": 0,
          "status": "ok",
          "timestamp": 1631062866650,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "Di9bklyXzJJb"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_EOAeIqOzQdT"
      },
      "source": [
        "### Composable Augmentation Encoding for Video Representation Learning (CATE)\n",
        "\n",
        "[arXiv](https://arxiv.org/abs/2104.00616)\n",
        "\n",
        "[Project page](https://sites.google.com/brown.edu/cate-iccv2021/)\n",
        "\n",
        "This Colab demonstrates how to load pretrained CATE models from TensorFlow Hub modules and run inference on video frames. It also includes an example of a nearest-neighbor classification experiment on the UCF-101 dataset.\n",
        "\n",
        "The checkpoints are accessible in the following Google Cloud storage directories:\n",
        "\n",
        "  - gs://gresearch/cate-iccv2021/kinetics400/\n",
        "  \n",
        "  - gs://gresearch/cate-iccv2021/something_v1/\n",
        "\n",
        "  - gs://gresearch/cate-iccv2021/something_v2/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "executionInfo": {
          "elapsed": 0,
          "status": "ok",
          "timestamp": 1631062866650,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "hCzElj_Gg2Ly"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import numpy as np\n",
        "from sklearn.metrics.pairwise import cosine_similarity\n",
        "\n",
        "import tensorflow.compat.v1 as tf\n",
        "tf.disable_eager_execution()\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "executionInfo": {
          "elapsed": 1,
          "status": "ok",
          "timestamp": 1631062866651,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "HORumkOgvSQE"
      },
      "outputs": [],
      "source": [
        "#@title Define video preprocessing functions for inference.\n",
        "\n",
        "def sample_linspace_frames(frames, num_frames, num_windows):\n",
        "  \"\"\"Gathers `num_windows` linearly spaced windows of `num_frames` frames.\n",
        "\n",
        "  When the clip has fewer than `num_frames` frames, its frame indices are\n",
        "  repeated so that a full window can still be sliced.\n",
        "\n",
        "  Args:\n",
        "    frames: Tensor whose first dimension indexes the frames.\n",
        "    num_frames: Number of consecutive frame indices per window.\n",
        "    num_windows: Python int, number of windows to extract.\n",
        "\n",
        "  Returns:\n",
        "    `frames` gathered along the first dimension, with the `num_windows`\n",
        "    windows of `num_frames` frames laid out one after another.\n",
        "  \"\"\"\n",
        "  clip_length = tf.shape(frames)[0]\n",
        "  # Number of times the index range must be repeated to reach `num_frames`.\n",
        "  repeats = tf.cast(\n",
        "      tf.ceil(tf.cast(num_frames, tf.float32) /\n",
        "              tf.cast(clip_length, tf.float32)), tf.int32)\n",
        "  frame_idx = tf.tile(tf.range(clip_length), [repeats])\n",
        "  padded_length = tf.maximum(clip_length, num_frames)\n",
        "  # Window starts run linearly from 0 to the last start that still fits.\n",
        "  last_start = tf.cast(padded_length - num_frames, tf.float32)\n",
        "  starts = tf.cast(tf.linspace(0.0, last_start, num_windows), tf.int32)\n",
        "  windows = [\n",
        "      tf.slice(frame_idx, [starts[w]], [num_frames])\n",
        "      for w in range(num_windows)\n",
        "  ]\n",
        "  return tf.gather(frames, tf.concat(windows, axis=0))\n",
        "\n",
        "def _compute_crop_shape(\n",
        "    image_height, image_width, aspect_ratio, crop_proportion):\n",
        "  \"\"\"Computes the (height, width) of a crop with a requested aspect ratio.\n",
        "\n",
        "  Args:\n",
        "    image_height: Height of the source image.\n",
        "    image_width: Width of the source image.\n",
        "    aspect_ratio: Requested width / height ratio of the crop.\n",
        "    crop_proportion: Fraction of the limiting image side covered by the crop.\n",
        "\n",
        "  Returns:\n",
        "    A `(crop_height, crop_width)` pair of int32 tensors.\n",
        "  \"\"\"\n",
        "  width_f = tf.cast(image_width, tf.float32)\n",
        "  height_f = tf.cast(image_height, tf.float32)\n",
        "\n",
        "  def _round_to_int(value):\n",
        "    return tf.cast(tf.rint(value), tf.int32)\n",
        "\n",
        "  def _crop_limited_by_width():\n",
        "    # The requested ratio is wider than the image: the width is the limit.\n",
        "    crop_height = _round_to_int(crop_proportion / aspect_ratio * width_f)\n",
        "    crop_width = _round_to_int(crop_proportion * width_f)\n",
        "    return crop_height, crop_width\n",
        "\n",
        "  def _crop_limited_by_height():\n",
        "    # The image is wider than the requested ratio: the height is the limit.\n",
        "    crop_height = _round_to_int(crop_proportion * height_f)\n",
        "    crop_width = _round_to_int(crop_proportion * aspect_ratio * height_f)\n",
        "    return crop_height, crop_width\n",
        "\n",
        "  return tf.cond(\n",
        "      aspect_ratio \u003e width_f / height_f,\n",
        "      _crop_limited_by_width,\n",
        "      _crop_limited_by_height)\n",
        "\n",
        "def center_crop(images, height, width, crop_proportion):\n",
        "  \"\"\"Center-crops a batch of images and resizes it to `height` x `width`.\n",
        "\n",
        "  Args:\n",
        "    images: 4-D tensor; dimensions 1 and 2 are read as image height and width.\n",
        "    height: Output height.\n",
        "    width: Output width.\n",
        "    crop_proportion: Fraction of the limiting image side kept by the crop.\n",
        "\n",
        "  Returns:\n",
        "    The center crop of `images`, bicubically resized to `[height, width]`.\n",
        "  \"\"\"\n",
        "  shape = tf.shape(images)\n",
        "  image_height = shape[1]\n",
        "  image_width = shape[2]\n",
        "  # `_compute_crop_shape` interprets its aspect ratio as width / height (it\n",
        "  # compares it against image_width / image_height), so pass it that way.\n",
        "  # For square outputs this equals the previously used height / width.\n",
        "  crop_height, crop_width = _compute_crop_shape(\n",
        "      image_height, image_width, width / height, crop_proportion)\n",
        "  # The +1 rounds the offsets up when the margin is odd.\n",
        "  offset_height = ((image_height - crop_height) + 1) // 2\n",
        "  offset_width = ((image_width - crop_width) + 1) // 2\n",
        "  images = tf.image.crop_to_bounding_box(\n",
        "      images, offset_height, offset_width, crop_height, crop_width)\n",
        "  images = tf.image.resize_bicubic(images, [height, width])\n",
        "  return images\n",
        "\n",
        "def preprocess_video(video, num_frames, height, width, num_windows):\n",
        "  \"\"\"Turns a decoded video into fixed-size float32 windows for inference.\n",
        "\n",
        "  Args:\n",
        "    video: Frame tensor of one video; assumed to be [frames, H, W, 3] - TODO\n",
        "      confirm against the dataset.\n",
        "    num_frames: Number of frames per window.\n",
        "    height: Output frame height.\n",
        "    width: Output frame width.\n",
        "    num_windows: Number of windows sampled from the video.\n",
        "\n",
        "  Returns:\n",
        "    A float32 tensor of shape [num_windows, num_frames, height, width, 3]\n",
        "    with values clipped to [0, 1].\n",
        "  \"\"\"\n",
        "  clips = sample_linspace_frames(video, num_frames, num_windows)\n",
        "  clips = tf.image.convert_image_dtype(clips, dtype=tf.float32)\n",
        "  cropped = center_crop(clips, height, width, crop_proportion=0.875)\n",
        "  # Keep the resized values inside the [0, 1] range.\n",
        "  cropped = tf.clip_by_value(cropped, 0., 1.)\n",
        "  return tf.reshape(cropped, [num_windows, num_frames, height, width, 3])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 1,
          "status": "ok",
          "timestamp": 1631062866651,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "qqB0bpIgrUmL"
      },
      "outputs": [],
      "source": [
        "#@title Load tfhub checkpoint from Cloud storage.\n",
        "\n",
        "# Directory of the hub module; this one is the kinetics400 checkpoint.\n",
        "hub_path = 'gs://gresearch/cate-iccv2021/kinetics400/'\n",
        "# Load the module with trainable=False, i.e. for inference only.\n",
        "module = hub.Module(hub_path, trainable=False)\n",
        "# This session is reused by the feature extraction cells below.\n",
        "sess = tf.Session()\n",
        "# NOTE(review): presumably this is what loads the pretrained module weights\n",
        "# into the session - confirm against the TF-Hub documentation.\n",
        "sess.run(tf.global_variables_initializer())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "executionInfo": {
          "elapsed": 0,
          "status": "ok",
          "timestamp": 1631062866677,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "ScOvuABertqB",
        "outputId": "02fd511c-4bac-4355-a940-5fc1a23b96f7"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "9537\n",
            "3783\n"
          ]
        }
      ],
      "source": [
        "#@title Set up data loader for UCF-101.\n",
        "\n",
        "# Each video is represented by 8 sliding windows of 32 frames; the window\n",
        "# representations are averaged in the nearest neighbor evaluation.\n",
        "num_windows = 8\n",
        "num_frames = 32\n",
        "dataset_name = 'ucf101/ucf101_1'\n",
        "\n",
        "# Load both splits together with their dataset metadata.\n",
        "train_dataset, train_info = tfds.load(dataset_name, split='train', with_info=True)\n",
        "test_dataset, test_info = tfds.load(dataset_name, split='test', with_info=True)\n",
        "num_train_examples = train_info.splits['train'].num_examples\n",
        "num_test_examples = test_info.splits['test'].num_examples\n",
        "num_classes = train_info.features['label'].num_classes\n",
        "\n",
        "def _preprocess(x):\n",
        "  \"\"\"Replaces the 'video' entry of `x` with preprocessed 224x224 windows.\"\"\"\n",
        "  x['video'] = preprocess_video(x['video'], num_frames, 224, 224, num_windows)\n",
        "  return x\n",
        "\n",
        "def _one_example_at_a_time(dataset):\n",
        "  \"\"\"Returns the get_next() tensors of a one-shot iterator over batches of 1.\"\"\"\n",
        "  batched = dataset.map(_preprocess).batch(1)\n",
        "  return tf.data.make_one_shot_iterator(batched).get_next()\n",
        "\n",
        "x_train = _one_example_at_a_time(train_dataset)\n",
        "x_test = _one_example_at_a_time(test_dataset)\n",
        "\n",
        "print(num_train_examples)\n",
        "print(num_test_examples)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 0,
          "status": "ok",
          "timestamp": 1631062866679,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "t0KgSTRTHIjh",
        "outputId": "c07f502d-9f3d-4678-9e28-92b7773b81dc"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0 out of 9537 examples processed.\n",
            "100 out of 9537 examples processed.\n",
            "200 out of 9537 examples processed.\n",
            "300 out of 9537 examples processed.\n",
            "400 out of 9537 examples processed.\n",
            "500 out of 9537 examples processed.\n",
            "600 out of 9537 examples processed.\n",
            "700 out of 9537 examples processed.\n",
            "800 out of 9537 examples processed.\n",
            "900 out of 9537 examples processed.\n",
            "1000 out of 9537 examples processed.\n",
            "1100 out of 9537 examples processed.\n",
            "1200 out of 9537 examples processed.\n",
            "1300 out of 9537 examples processed.\n",
            "1400 out of 9537 examples processed.\n",
            "1500 out of 9537 examples processed.\n",
            "1600 out of 9537 examples processed.\n",
            "1700 out of 9537 examples processed.\n",
            "1800 out of 9537 examples processed.\n",
            "1900 out of 9537 examples processed.\n",
            "2000 out of 9537 examples processed.\n",
            "2100 out of 9537 examples processed.\n",
            "2200 out of 9537 examples processed.\n",
            "2300 out of 9537 examples processed.\n",
            "2400 out of 9537 examples processed.\n",
            "2500 out of 9537 examples processed.\n",
            "2600 out of 9537 examples processed.\n",
            "2700 out of 9537 examples processed.\n",
            "2800 out of 9537 examples processed.\n",
            "2900 out of 9537 examples processed.\n",
            "3000 out of 9537 examples processed.\n",
            "3100 out of 9537 examples processed.\n",
            "3200 out of 9537 examples processed.\n",
            "3300 out of 9537 examples processed.\n",
            "3400 out of 9537 examples processed.\n",
            "3500 out of 9537 examples processed.\n",
            "3600 out of 9537 examples processed.\n",
            "3700 out of 9537 examples processed.\n",
            "3800 out of 9537 examples processed.\n",
            "3900 out of 9537 examples processed.\n",
            "4000 out of 9537 examples processed.\n",
            "4100 out of 9537 examples processed.\n",
            "4200 out of 9537 examples processed.\n",
            "4300 out of 9537 examples processed.\n",
            "4400 out of 9537 examples processed.\n",
            "4500 out of 9537 examples processed.\n",
            "4600 out of 9537 examples processed.\n",
            "4700 out of 9537 examples processed.\n",
            "4800 out of 9537 examples processed.\n",
            "4900 out of 9537 examples processed.\n",
            "5000 out of 9537 examples processed.\n",
            "5100 out of 9537 examples processed.\n",
            "5200 out of 9537 examples processed.\n",
            "5300 out of 9537 examples processed.\n",
            "5400 out of 9537 examples processed.\n",
            "5500 out of 9537 examples processed.\n",
            "5600 out of 9537 examples processed.\n",
            "5700 out of 9537 examples processed.\n",
            "5800 out of 9537 examples processed.\n",
            "5900 out of 9537 examples processed.\n",
            "6000 out of 9537 examples processed.\n",
            "6100 out of 9537 examples processed.\n",
            "6200 out of 9537 examples processed.\n",
            "6300 out of 9537 examples processed.\n",
            "6400 out of 9537 examples processed.\n",
            "6500 out of 9537 examples processed.\n",
            "6600 out of 9537 examples processed.\n",
            "6700 out of 9537 examples processed.\n",
            "6800 out of 9537 examples processed.\n",
            "6900 out of 9537 examples processed.\n",
            "7000 out of 9537 examples processed.\n",
            "7100 out of 9537 examples processed.\n",
            "7200 out of 9537 examples processed.\n",
            "7300 out of 9537 examples processed.\n",
            "7400 out of 9537 examples processed.\n",
            "7500 out of 9537 examples processed.\n",
            "7600 out of 9537 examples processed.\n",
            "7700 out of 9537 examples processed.\n",
            "7800 out of 9537 examples processed.\n",
            "7900 out of 9537 examples processed.\n",
            "8000 out of 9537 examples processed.\n",
            "8100 out of 9537 examples processed.\n",
            "8200 out of 9537 examples processed.\n",
            "8300 out of 9537 examples processed.\n",
            "8400 out of 9537 examples processed.\n",
            "8500 out of 9537 examples processed.\n",
            "8600 out of 9537 examples processed.\n",
            "8700 out of 9537 examples processed.\n",
            "8800 out of 9537 examples processed.\n",
            "8900 out of 9537 examples processed.\n",
            "9000 out of 9537 examples processed.\n",
            "9100 out of 9537 examples processed.\n",
            "9200 out of 9537 examples processed.\n",
            "9300 out of 9537 examples processed.\n",
            "9400 out of 9537 examples processed.\n",
            "9500 out of 9537 examples processed.\n"
          ]
        }
      ],
      "source": [
        "#@title Compute the final activations from a pretrained ResNet-3D-50 network.\n",
        "\n",
        "# Drop the leading batch dimension of size 1 so that the windows form the\n",
        "# batch fed to the module: [num_windows, num_frames, 224, 224, 3]. A separate\n",
        "# name is used so the shared `x_train` dict is not modified in place.\n",
        "train_video = tf.reshape(x_train['video'],\n",
        "                         [num_windows, num_frames, 224, 224, 3])\n",
        "output_train = module(inputs=train_video, signature='default', as_dict=True)\n",
        "\n",
        "train_data = []\n",
        "for i in range(num_train_examples):\n",
        "  # Label and features are fetched in a single run() call so that both come\n",
        "  # from the same example. The video tensor itself is not fetched: it is\n",
        "  # unused here and copying it out on every step is expensive.\n",
        "  label, hiddens = sess.run((x_train['label'], output_train['hiddens']))\n",
        "  train_data.append((label, hiddens))\n",
        "  if i % 100 == 0:\n",
        "    print('%d out of %d examples processed.' % (i, num_train_examples))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "executionInfo": {
          "elapsed": 666396,
          "status": "ok",
          "timestamp": 1631063552380,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "4SFT7amsH2GN",
        "outputId": "f9583329-8dbb-4708-d8e2-33b5f44aeec7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0 out of 3783 examples processed.\n",
            "100 out of 3783 examples processed.\n",
            "200 out of 3783 examples processed.\n",
            "300 out of 3783 examples processed.\n",
            "400 out of 3783 examples processed.\n",
            "500 out of 3783 examples processed.\n",
            "600 out of 3783 examples processed.\n",
            "700 out of 3783 examples processed.\n",
            "800 out of 3783 examples processed.\n",
            "900 out of 3783 examples processed.\n",
            "1000 out of 3783 examples processed.\n",
            "1100 out of 3783 examples processed.\n",
            "1200 out of 3783 examples processed.\n",
            "1300 out of 3783 examples processed.\n",
            "1400 out of 3783 examples processed.\n",
            "1500 out of 3783 examples processed.\n",
            "1600 out of 3783 examples processed.\n",
            "1700 out of 3783 examples processed.\n",
            "1800 out of 3783 examples processed.\n",
            "1900 out of 3783 examples processed.\n",
            "2000 out of 3783 examples processed.\n",
            "2100 out of 3783 examples processed.\n",
            "2200 out of 3783 examples processed.\n",
            "2300 out of 3783 examples processed.\n",
            "2400 out of 3783 examples processed.\n",
            "2500 out of 3783 examples processed.\n",
            "2600 out of 3783 examples processed.\n",
            "2700 out of 3783 examples processed.\n",
            "2800 out of 3783 examples processed.\n",
            "2900 out of 3783 examples processed.\n",
            "3000 out of 3783 examples processed.\n",
            "3100 out of 3783 examples processed.\n",
            "3200 out of 3783 examples processed.\n",
            "3300 out of 3783 examples processed.\n",
            "3400 out of 3783 examples processed.\n",
            "3500 out of 3783 examples processed.\n",
            "3600 out of 3783 examples processed.\n",
            "3700 out of 3783 examples processed.\n"
          ]
        }
      ],
      "source": [
        "# Drop the leading batch dimension of size 1 so that the windows form the\n",
        "# batch fed to the module: [num_windows, num_frames, 224, 224, 3]. A separate\n",
        "# name is used so the shared `x_test` dict is not modified in place.\n",
        "test_video = tf.reshape(x_test['video'],\n",
        "                        [num_windows, num_frames, 224, 224, 3])\n",
        "output_test = module(inputs=test_video, signature='default', as_dict=True)\n",
        "\n",
        "test_data = []\n",
        "for i in range(num_test_examples):\n",
        "  # Label and features are fetched in a single run() call so that both come\n",
        "  # from the same example. The video tensor itself is not fetched: it is\n",
        "  # unused here and copying it out on every step is expensive.\n",
        "  label, hiddens = sess.run((x_test['label'], output_test['hiddens']))\n",
        "  test_data.append((label, hiddens))\n",
        "  if i % 100 == 0:\n",
        "    print('%d out of %d examples processed.' % (i, num_test_examples))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "executionInfo": {
          "elapsed": 17145,
          "status": "ok",
          "timestamp": 1631063601755,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "aPM5U1HXiJDT",
        "outputId": "575e79b1-e6e8-46e4-fbaa-d0fe2da97274"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "top 1 recall: 0.548771\n",
            "top 5 recall: 0.683320\n",
            "top 10 recall: 0.750991\n",
            "top 20 recall: 0.822892\n",
            "top 50 recall: 0.898758\n"
          ]
        }
      ],
      "source": [
        "#@title Run nearest neighbor classification.\n",
        "\n",
        "# We follow the standard setup to evaluate nearest neighbor video retrieval:\n",
        "# For each video in the test set, we query its K nearest neighbors from the\n",
        "# training set. If any label of the retrieved training examples matches the\n",
        "# ground truth label of the query example, we deem it a correct match.\n",
        "\n",
        "def prepare_knn_data(d):\n",
        "  \"\"\"Builds L2-normalized per-video features and a flat label vector.\n",
        "\n",
        "  Reads the module-level `num_windows`.\n",
        "\n",
        "  Args:\n",
        "    d: Sequence of `(label, hiddens)` pairs, one per video.\n",
        "\n",
        "  Returns:\n",
        "    A pair `(x, y)`: `x` holds one feature row per video (the mean over the\n",
        "    `num_windows` window features, scaled to unit L2 norm) and `y` holds the\n",
        "    labels as a 1-D array.\n",
        "  \"\"\"\n",
        "  labels = [example[0] for example in d]\n",
        "  window_features = [example[1] for example in d]\n",
        "  x = np.concatenate(window_features, 0).reshape((-1, num_windows, 2048))\n",
        "  # Average pool over the windows of each video.\n",
        "  x = np.mean(x, 1)\n",
        "  # Scale every row to unit L2 norm.\n",
        "  x /= np.linalg.norm(x, axis=1, keepdims=True)\n",
        "  y = np.concatenate(labels, 0).reshape((-1))\n",
        "  return x, y\n",
        "\n",
        "def compute_cosine_dist(x_train, x_test):\n",
        "  \"\"\"Returns cosine distances with one row per test and one column per train example.\"\"\"\n",
        "  similarity = cosine_similarity(x_test, x_train)\n",
        "  return 1 - similarity\n",
        "\n",
        "def knn_evaluation(dist, y_train, y_test, k):\n",
        "  \"\"\"Prints the top-k recall of nearest neighbor retrieval.\n",
        "\n",
        "  A query counts as a hit when at least one of its `k` nearest training\n",
        "  examples carries the same label as the query.\n",
        "\n",
        "  Args:\n",
        "    dist: 2-D array; `dist[i][j]` is the distance between query `i` and\n",
        "      training example `j`.\n",
        "    y_train: Labels of the training examples, in column order of `dist`.\n",
        "    y_test: Labels of the queries, in row order of `dist`.\n",
        "    k: Number of nearest neighbors inspected per query.\n",
        "  \"\"\"\n",
        "  y_train = np.asarray(y_train)\n",
        "  num_samples = dist.shape[0]\n",
        "  hit = 0\n",
        "  for i in range(num_samples):\n",
        "    # Slicing instead of indexing position by position keeps this valid when\n",
        "    # k is larger than the number of training examples.\n",
        "    nearest = np.argsort(dist[i])[:k]\n",
        "    if np.any(y_train[nearest] == y_test[i]):\n",
        "      hit += 1\n",
        "  recall = hit / num_samples\n",
        "  print('top %d recall: %f' % (k, recall))\n",
        "\n",
        "# The features get their own names: `x_train` / `x_test` hold the input\n",
        "# pipeline tensors read by the feature extraction cells, and rebinding them\n",
        "# to NumPy arrays here would break re-running those cells.\n",
        "train_features, y_train = prepare_knn_data(train_data)\n",
        "test_features, y_test = prepare_knn_data(test_data)\n",
        "dist = compute_cosine_dist(train_features, test_features)\n",
        "for k in (1, 5, 10, 20, 50):\n",
        "  knn_evaluation(dist, y_train, y_test, k)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Composable AugmenTation Encoding (CATE)",
      "provenance": [
        {
          "file_id": "1OIOtkIaFxSdjJLVLhwn85NQZRLvwdZ65",
          "timestamp": 1624749879665
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
